{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "HW04.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zC5KwRyl6Flp"
      },
      "source": [
        "# Task description\n",
        "- Classify the speakers of given features.\n",
        "- Main goal: Learn how to use a transformer.\n",
        "- Baselines:\n",
        "  - Easy: Run the sample code and understand how the transformer is used.\n",
        "  - Medium: Know how to adjust the parameters of the transformer.\n",
        "  - Hard: Construct a [conformer](https://arxiv.org/abs/2005.08100), which is a variant of the transformer.\n",
        "\n",
        "- Other links\n",
        "  - Kaggle: [link](https://www.kaggle.com/t/859c9ca9ede14fdea841be627c412322)\n",
        "  - Slide: [link](https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/hw/HW04/HW04.pdf)\n",
        "  - Data: [link](https://drive.google.com/file/d/1T0RPnu-Sg5eIPwQPfYysipfcz81MnsYe/view?usp=sharing)\n",
        "  - Video (Chinese): [link](https://www.youtube.com/watch?v=EPerg2UnGaI)\n",
        "  - Video (English): [link](https://www.youtube.com/watch?v=Gpz6AUvCak0)\n",
        "  - Workaround if the dataset download fails: [link](https://drive.google.com/drive/folders/13T0Pa_WGgQxNkqZk781qhc5T9-zfh19e?usp=sharing)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TPDoreyypeJE"
      },
      "source": [
        "# Download dataset\n",
        "- Please follow [here](https://drive.google.com/drive/folders/13T0Pa_WGgQxNkqZk781qhc5T9-zfh19e?usp=sharing) to download data\n",
        "- Data is [here](https://drive.google.com/file/d/1gaFy8RaQVUEXo2n0peCBR5gYKCB-mNHc/view?usp=sharing)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QvpaILXnJIcw"
      },
      "source": [
        "!gdown --id 'paste your own data download link' --output Dataset.zip\n",
        "!unzip Dataset.zip"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v1gYr_aoNDue"
      },
      "source": [
        "# Data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mz_NpuAipk3h"
      },
      "source": [
        "## Dataset\n",
        "- The original dataset is [VoxCeleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/).\n",
        "- The [license](https://creativecommons.org/licenses/by/4.0/) and [complete version](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/files/license.txt) of VoxCeleb1.\n",
        "- We randomly select 600 speakers from VoxCeleb1.\n",
        "- We then preprocess the raw waveforms into mel-spectrograms.\n",
        "\n",
        "- Args:\n",
        "  - data_dir: The path to the data directory.\n",
        "  - metadata_path: The path to the metadata.\n",
        "  - segment_len: The length of the audio segment used for training.\n",
        "- The architecture of data directory \\\\\n",
        "  - data directory \\\\\n",
        "  |---- metadata.json \\\\\n",
        "  |---- testdata.json \\\\\n",
        "  |---- mapping.json \\\\\n",
        "  |---- uttr-{random string}.pt \\\\\n",
        "\n",
        "- The information in metadata\n",
        "  - \"n_mels\": The dimension of the mel-spectrogram.\n",
        "  - \"speakers\": A dictionary.\n",
        "    - Key: speaker ids.\n",
        "    - Value: a list of utterances, each with a \"feature_path\" and a \"mel_len\".\n",
        "\n",
        "\n",
        "For efficiency, we cut the mel-spectrograms into fixed-length segments in the training step."
      ]
    },
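    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The metadata layout and the segmenting step described above can be sketched in plain Python. This is a minimal illustration, not the dataset code itself: the speaker ids, file names, and lengths are invented, and `crop_segment` is a hypothetical helper. Only the keys `n_mels`, `speakers`, `feature_path`, and `mel_len` come from the description.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "# Hypothetical metadata using the documented keys; all values are invented.\n",
        "metadata = {\n",
        "    \"n_mels\": 40,\n",
        "    \"speakers\": {\n",
        "        \"id00001\": [\n",
        "            {\"feature_path\": \"uttr-aaaa.pt\", \"mel_len\": 435},\n",
        "            {\"feature_path\": \"uttr-bbbb.pt\", \"mel_len\": 120},\n",
        "        ],\n",
        "        \"id00002\": [\n",
        "            {\"feature_path\": \"uttr-cccc.pt\", \"mel_len\": 512},\n",
        "        ],\n",
        "    },\n",
        "}\n",
        "\n",
        "# Flatten the dictionary into (feature_path, speaker) pairs, one per utterance.\n",
        "pairs = [\n",
        "    (utterance[\"feature_path\"], speaker)\n",
        "    for speaker, utterances in metadata[\"speakers\"].items()\n",
        "    for utterance in utterances\n",
        "]\n",
        "\n",
        "\n",
        "def crop_segment(frames, segment_len):\n",
        "    \"\"\"Return a random window of segment_len frames, or all frames if there are no more than that.\"\"\"\n",
        "    if len(frames) <= segment_len:\n",
        "        return frames\n",
        "    start = random.randint(0, len(frames) - segment_len)\n",
        "    return frames[start:start + segment_len]\n",
        "\n",
        "\n",
        "# A list of frame indices stands in for a mel-spectrogram with 435 frames.\n",
        "segment = crop_segment(list(range(435)), segment_len=128)\n",
        "print(len(pairs), len(segment))\n",
        "```\n",
        "\n",
        "Because the start point is drawn at random each time, the same utterance can yield a different segment every time it is visited."
      ]
    },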
    {
      "cell_type": "code",
      "metadata": {
        "id": "cd7hoGhYtbXQ"
      },
      "source": [
        "import os\n",
        "import json\n",
        "import torch\n",
        "import random\n",
        "from pathlib import Path\n",
        "from torch.utils.data import Dataset\n",
        "from torch.nn.utils.rnn import pad_sequence\n",
        " \n",
        " \n",
        "class myDataset(Dataset):\n",
        "  def __init__(self, data_dir, segment_len=128):\n",
        "    self.data_dir = data_dir\n",
        "    self.segment_len = segment_len\n",
        " \n",
        "    # Load the mapping from speaker name to the corresponding id.\n",
        "    mapping_path = Path(data_dir) / \"mapping.json\"\n",
        "    mapping = json.load(mapping_path.open())\n",
        "    self.speaker2id = mapping[\"speaker2id\"]\n",
        " \n",
        "    # Load metadata of training data.\n",
        "    metadata_path = Path(data_dir) / \"metadata.json\"\n",
        "    metadata = json.load(open(metadata_path))[\"speakers\"]\n",
        " \n",
        "    # Get the total number of speakers.\n",
        "    self.speaker_num = len(metadata.keys())\n",
        "    self.data = []\n",
        "    for speaker in metadata.keys():\n",
        "      # Each speaker maps to a list of utterances.\n",
        "      for utterance in metadata[speaker]:\n",
        "        self.data.append([utterance[\"feature_path\"], self.speaker2id[speaker]])\n",
        " \n",
        "  def __len__(self):\n",
        "    return len(self.data)\n",
        " \n",
        "  def __getitem__(self, index):\n",
        "    feat_path, speaker = self.data[index]\n",
        "    # Load preprocessed mel-spectrogram.\n",
        "    mel = torch.load(os.path.join(self.data_dir, feat_path))\n",
        " \n",
        "    # Segment mel-spectrogram into \"segment_len\" frames.\n",
        "    if len(mel) > self.segment_len:\n",
        "      # Randomly get the starting point of the segment.\n",
        "      start = random.randint(0, len(mel) - self.segment_len)\n",
        "      # Get a segment with \"segment_len\" frames.\n",
        "      mel = torch.FloatTensor(mel[start:start+self.segment_len])\n",
        "    else:\n",
        "      mel = torch.FloatTensor(mel)\n",
        "    # Turn the speaker id into long for computing loss later.\n",
        "    speaker = torch.FloatTensor([speaker]).long()\n",
        "    return mel, speaker\n",
        " \n",
        "  def get_speaker_number(self):\n",
        "    return self.speaker_num"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mqJxjoi_NGnB"
      },
      "source": [
        "## Dataloader\n",
        "- Split the dataset into a training set (90%) and a validation set (10%).\n",
        "- Create dataloaders to iterate over the data.\n"
      ]
    },
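    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The two steps above can be sketched on toy tensors. This is only an illustration under stated assumptions: the utterance lengths and the dataset size of 1000 are made up, while the 40 mel bins and the padding value of -20 are the ones this notebook uses.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "from torch.nn.utils.rnn import pad_sequence\n",
        "\n",
        "# Three fake utterances with different lengths and 40 mel bins each.\n",
        "mels = [torch.zeros(5, 40), torch.zeros(3, 40), torch.zeros(4, 40)]\n",
        "\n",
        "# Pad every utterance to the length of the longest one in the batch.\n",
        "batch = pad_sequence(mels, batch_first=True, padding_value=-20)\n",
        "print(batch.shape)  # torch.Size([3, 5, 40])\n",
        "\n",
        "# 90% / 10% split sizes for a hypothetical dataset of 1000 items.\n",
        "n_items = 1000\n",
        "train_len = int(0.9 * n_items)\n",
        "valid_len = n_items - train_len\n",
        "print(train_len, valid_len)  # 900 100\n",
        "```\n",
        "\n",
        "Computing the validation size as the remainder guarantees that the two lengths always sum to the dataset size, which `random_split` requires."
      ]
    },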
    {
      "cell_type": "code",
      "metadata": {
        "id": "zuT1AuFENI8t"
      },
      "source": [
        "import torch\n",
        "from torch.utils.data import DataLoader, random_split\n",
        "from torch.nn.utils.rnn import pad_sequence\n",
        "\n",
        "\n",
        "def collate_batch(batch):\n",
        "  \"\"\"Collate a batch of data.\"\"\"\n",
        "  # Process features within a batch.\n",
        "  mel, speaker = zip(*batch)\n",
        "  # We train batch by batch, so pad the features in a batch to the same length.\n",
        "  # The padding value -20 corresponds to log 10^(-20), i.e. a very small value.\n",
        "  mel = pad_sequence(mel, batch_first=True, padding_value=-20)\n",
        "  # mel: (batch size, length, 40)\n",
        "  return mel, torch.FloatTensor(speaker).long()\n",
        "\n",
        "\n",
        "def get_dataloader(data_dir, batch_size, n_workers):\n",
        "  \"\"\"Generate dataloader\"\"\"\n",
        "  dataset = myDataset(data_dir)\n",
        "  speaker_num = dataset.get_speaker_number()\n",
        "  # Split dataset into training dataset and validation dataset\n",
        "  trainlen = int(0.9 * len(dataset))\n",
        "  lengths = [trainlen, len(dataset) - trainlen]\n",
        "  trainset, validset = random_split(dataset, lengths)\n",
        "\n",
        "  train_loader = DataLoader(\n",
        "    trainset,\n",
        "    batch_size=batch_size,\n",
        "    shuffle=True,\n",
        "    drop_last=True,\n",
        "    num_workers=n_workers,\n",
        "    pin_memory=True,\n",
        "    collate_fn=collate_batch,\n",
        "  )\n",
        "  valid_loader = DataLoader(\n",
        "    validset,\n",
        "    batch_size=batch_size,\n",
        "    num_workers=n_workers,\n",
        "    drop_last=True,\n",
        "    pin_memory=True,\n",
        "    collate_fn=collate_batch,\n",
        "  )\n",
        "\n",
        "  return train_loader, valid_loader, speaker_num\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X0x6eXiHpr4R"
      },
      "source": [
        "# Model\n",
        "- TransformerEncoderLayer:\n",
        "  - Base transformer encoder layer in [Attention Is All You Need](https://arxiv.org/abs/1706.03762)\n",
        "  - Parameters:\n",
        "    - d_model: the number of expected features of the input (required).\n",
        "\n",
        "    - nhead: the number of heads of the multiheadattention models (required).\n",
        "\n",
        "    - dim_feedforward: the dimension of the feedforward network model (default=2048).\n",
        "\n",
        "    - dropout: the dropout value (default=0.1).\n",
        "\n",
        "    - activation: the activation function of intermediate layer, relu or gelu (default=relu).\n",
        "\n",
        "- TransformerEncoder:\n",
        "  - TransformerEncoder is a stack of N transformer encoder layers\n",
        "  - Parameters:\n",
        "    - encoder_layer: an instance of the TransformerEncoderLayer() class (required).\n",
        "\n",
        "    - num_layers: the number of sub-encoder-layers in the encoder (required).\n",
        "\n",
        "    - norm: the layer normalization component (optional)."
      ]
    },
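    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal sketch of how the two classes listed above fit together. The sizes are chosen for illustration (they mirror the small model in this notebook), and the random input stands in for projected mel features.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "# Illustrative sizes, not tuned values.\n",
        "d_model, nhead, length, batch_size = 80, 2, 128, 4\n",
        "\n",
        "encoder_layer = nn.TransformerEncoderLayer(\n",
        "    d_model=d_model, nhead=nhead, dim_feedforward=256, dropout=0.1\n",
        ")\n",
        "# Stack two copies of the layer.\n",
        "encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)\n",
        "\n",
        "# Without batch_first, the expected input shape is (length, batch size, d_model).\n",
        "x = torch.randn(length, batch_size, d_model)\n",
        "out = encoder(x)\n",
        "print(out.shape)  # torch.Size([128, 4, 80])\n",
        "```\n",
        "\n",
        "The encoder preserves the input shape, so pooling over the length dimension is still needed before classification."
      ]
    },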
    {
      "cell_type": "code",
      "metadata": {
        "id": "SHX4eVj4tjtd"
      },
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "\n",
        "class Classifier(nn.Module):\n",
        "  def __init__(self, d_model=80, n_spks=600, dropout=0.1):\n",
        "    super().__init__()\n",
        "    # Project the dimension of features from that of input into d_model.\n",
        "    self.prenet = nn.Linear(40, d_model)\n",
        "    # TODO:\n",
        "    #   Change Transformer to Conformer.\n",
        "    #   https://arxiv.org/abs/2005.08100\n",
        "    self.encoder_layer = nn.TransformerEncoderLayer(\n",
        "      d_model=d_model, dim_feedforward=256, nhead=2\n",
        "    )\n",
        "    # self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=2)\n",
        "\n",
        "    # Project the dimension of features from d_model into the number of speakers.\n",
        "    self.pred_layer = nn.Sequential(\n",
        "      nn.Linear(d_model, d_model),\n",
        "      nn.ReLU(),\n",
        "      nn.Linear(d_model, n_spks),\n",
        "    )\n",
        "\n",
        "  def forward(self, mels):\n",
        "    \"\"\"\n",
        "    args:\n",
        "      mels: (batch size, length, 40)\n",
        "    return:\n",
        "      out: (batch size, n_spks)\n",
        "    \"\"\"\n",
        "    # out: (batch size, length, d_model)\n",
        "    out = self.prenet(mels)\n",
        "    # out: (length, batch size, d_model)\n",
        "    out = out.permute(1, 0, 2)\n",
        "    # The encoder layer expects features in the shape of (length, batch size, d_model).\n",
        "    out = self.encoder_layer(out)\n",
        "    # out: (batch size, length, d_model)\n",
        "    out = out.transpose(0, 1)\n",
        "    # mean pooling\n",
        "    stats = out.mean(dim=1)\n",
        "\n",
        "    # out: (batch, n_spks)\n",
        "    out = self.pred_layer(stats)\n",
        "    return out\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-__DolPGpvDZ"
      },
      "source": [
        "# Learning rate schedule\n",
        "- For the transformer architecture, the design of the learning rate schedule differs from that of a CNN.\n",
        "- Previous works show that learning rate warmup is useful for training models with transformer architectures.\n",
        "- The warmup schedule\n",
        "  - Set the learning rate to 0 at the beginning.\n",
        "  - The learning rate increases linearly from 0 to the initial learning rate during the warmup period."
      ]
    },
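    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The warmup described above can be worked through with plain arithmetic. The sketch below computes the factor that multiplies the initial learning rate at a given step. The function name `lr_multiplier` is hypothetical, and the half-cosine decay after warmup is an assumption that goes beyond the bullets above, which only describe the warmup itself. The step counts (1000 warmup steps, 70000 total steps) match the defaults used later in this notebook.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def lr_multiplier(step, warmup_steps, total_steps):\n",
        "    \"\"\"Factor applied to the initial learning rate at a given step.\"\"\"\n",
        "    if step < warmup_steps:\n",
        "        # Linear warmup from 0 to 1.\n",
        "        return step / max(1, warmup_steps)\n",
        "    # Half-cosine decay from 1 to 0 over the remaining steps.\n",
        "    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)\n",
        "    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))\n",
        "\n",
        "\n",
        "for step in (0, 500, 1000, 35500, 70000):\n",
        "    print(step, round(lr_multiplier(step, 1000, 70000), 4))\n",
        "```\n",
        "\n",
        "With these settings the factor is 0.0 at step 0, 0.5 halfway through warmup, 1.0 when warmup ends, 0.5 halfway through the decay, and 0.0 at the final step."
      ]
    },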
    {
      "cell_type": "code",
      "metadata": {
        "id": "K-0816BntqT9"
      },
      "source": [
        "import math\n",
        "\n",
        "import torch\n",
        "from torch.optim import Optimizer\n",
        "from torch.optim.lr_scheduler import LambdaLR\n",
        "\n",
        "\n",
        "def get_cosine_schedule_with_warmup(\n",
        "  optimizer: Optimizer,\n",
        "  num_warmup_steps: int,\n",
        "  num_training_steps: int,\n",
        "  num_cycles: float = 0.5,\n",
        "  last_epoch: int = -1,\n",
        "):\n",
        "  \"\"\"\n",
        "  Create a schedule with a learning rate that decreases following the values of the cosine function between the\n",
        "  initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the\n",
        "  initial lr set in the optimizer.\n",
        "\n",
        "  Args:\n",
        "    optimizer (:class:`~torch.optim.Optimizer`):\n",
        "      The optimizer for which to schedule the learning rate.\n",
        "    num_warmup_steps (:obj:`int`):\n",
        "      The number of steps for the warmup phase.\n",
        "    num_training_steps (:obj:`int`):\n",
        "      The total number of training steps.\n",
        "    num_cycles (:obj:`float`, `optional`, defaults to 0.5):\n",
        "      The number of waves in the cosine schedule (the default is to just decrease from the max value to 0\n",
        "      following a half-cosine).\n",
        "    last_epoch (:obj:`int`, `optional`, defaults to -1):\n",
        "      The index of the last epoch when resuming training.\n",
        "\n",
        "  Return:\n",
        "    :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n",
        "  \"\"\"\n",
        "\n",
        "  def lr_lambda(current_step):\n",
        "    # Warmup\n",
        "    if current_step < num_warmup_steps:\n",
        "      return float(current_step) / float(max(1, num_warmup_steps))\n",
        "    # Cosine decay\n",
        "    progress = float(current_step - num_warmup_steps) / float(\n",
        "      max(1, num_training_steps - num_warmup_steps)\n",
        "    )\n",
        "    return max(\n",
        "      0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))\n",
        "    )\n",
        "\n",
        "  return LambdaLR(optimizer, lr_lambda, last_epoch)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IP03FFo9K8DS"
      },
      "source": [
        "# Model Function\n",
        "- Model forward function."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fohaLEFJK9-t"
      },
      "source": [
        "import torch\n",
        "\n",
        "\n",
        "def model_fn(batch, model, criterion, device):\n",
        "  \"\"\"Forward a batch through the model.\"\"\"\n",
        "\n",
        "  mels, labels = batch\n",
        "  mels = mels.to(device)\n",
        "  labels = labels.to(device)\n",
        "\n",
        "  outs = model(mels)\n",
        "\n",
        "  loss = criterion(outs, labels)\n",
        "\n",
        "  # Get the speaker id with highest probability.\n",
        "  preds = outs.argmax(1)\n",
        "  # Compute accuracy.\n",
        "  accuracy = torch.mean((preds == labels).float())\n",
        "\n",
        "  return loss, accuracy\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F7cg-YrzLQcf"
      },
      "source": [
        "# Validate\n",
        "- Calculate accuracy of the validation set."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mD-_p6nWLO2L"
      },
      "source": [
        "from tqdm import tqdm\n",
        "import torch\n",
        "\n",
        "\n",
        "def valid(dataloader, model, criterion, device): \n",
        "  \"\"\"Validate on validation set.\"\"\"\n",
        "\n",
        "  model.eval()\n",
        "  running_loss = 0.0\n",
        "  running_accuracy = 0.0\n",
        "  pbar = tqdm(total=len(dataloader.dataset), ncols=0, desc=\"Valid\", unit=\" uttr\")\n",
        "\n",
        "  for i, batch in enumerate(dataloader):\n",
        "    with torch.no_grad():\n",
        "      loss, accuracy = model_fn(batch, model, criterion, device)\n",
        "      running_loss += loss.item()\n",
        "      running_accuracy += accuracy.item()\n",
        "\n",
        "    pbar.update(dataloader.batch_size)\n",
        "    pbar.set_postfix(\n",
        "      loss=f\"{running_loss / (i+1):.2f}\",\n",
        "      accuracy=f\"{running_accuracy / (i+1):.2f}\",\n",
        "    )\n",
        "\n",
        "  pbar.close()\n",
        "  model.train()\n",
        "\n",
        "  return running_accuracy / len(dataloader)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "noHXyal5p1W5"
      },
      "source": [
        "# Main function"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "chRQE7oYtw62"
      },
      "source": [
        "from tqdm import tqdm\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.optim import AdamW\n",
        "from torch.utils.data import DataLoader, random_split\n",
        "\n",
        "\n",
        "def parse_args():\n",
        "  \"\"\"arguments\"\"\"\n",
        "  config = {\n",
        "    \"data_dir\": \"./Dataset\",\n",
        "    \"save_path\": \"model.ckpt\",\n",
        "    \"batch_size\": 32,\n",
        "    \"n_workers\": 8,\n",
        "    \"valid_steps\": 2000,\n",
        "    \"warmup_steps\": 1000,\n",
        "    \"save_steps\": 10000,\n",
        "    \"total_steps\": 70000,\n",
        "  }\n",
        "\n",
        "  return config\n",
        "\n",
        "\n",
        "def main(\n",
        "  data_dir,\n",
        "  save_path,\n",
        "  batch_size,\n",
        "  n_workers,\n",
        "  valid_steps,\n",
        "  warmup_steps,\n",
        "  total_steps,\n",
        "  save_steps,\n",
        "):\n",
        "  \"\"\"Main function.\"\"\"\n",
        "  device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "  print(f\"[Info]: Use {device} now!\")\n",
        "\n",
        "  train_loader, valid_loader, speaker_num = get_dataloader(data_dir, batch_size, n_workers)\n",
        "  train_iterator = iter(train_loader)\n",
        "  print(\"[Info]: Finish loading data!\", flush=True)\n",
        "\n",
        "  model = Classifier(n_spks=speaker_num).to(device)\n",
        "  criterion = nn.CrossEntropyLoss()\n",
        "  optimizer = AdamW(model.parameters(), lr=1e-3)\n",
        "  scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)\n",
        "  print(\"[Info]: Finish creating model!\", flush=True)\n",
        "\n",
        "  best_accuracy = -1.0\n",
        "  best_state_dict = None\n",
        "\n",
        "  pbar = tqdm(total=valid_steps, ncols=0, desc=\"Train\", unit=\" step\")\n",
        "\n",
        "  for step in range(total_steps):\n",
        "    # Get data\n",
        "    try:\n",
        "      batch = next(train_iterator)\n",
        "    except StopIteration:\n",
        "      train_iterator = iter(train_loader)\n",
        "      batch = next(train_iterator)\n",
        "\n",
        "    loss, accuracy = model_fn(batch, model, criterion, device)\n",
        "    batch_loss = loss.item()\n",
        "    batch_accuracy = accuracy.item()\n",
        "\n",
        "    # Update model\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    scheduler.step()\n",
        "    optimizer.zero_grad()\n",
        "    \n",
        "    # Log\n",
        "    pbar.update()\n",
        "    pbar.set_postfix(\n",
        "      loss=f\"{batch_loss:.2f}\",\n",
        "      accuracy=f\"{batch_accuracy:.2f}\",\n",
        "      step=step + 1,\n",
        "    )\n",
        "\n",
        "    # Do validation\n",
        "    if (step + 1) % valid_steps == 0:\n",
        "      pbar.close()\n",
        "\n",
        "      valid_accuracy = valid(valid_loader, model, criterion, device)\n",
        "\n",
        "      # keep the best model\n",
        "      if valid_accuracy > best_accuracy:\n",
        "        best_accuracy = valid_accuracy\n",
        "        # Clone the tensors: state_dict() returns references that later updates would overwrite.\n",
        "        best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}\n",
        "\n",
        "      pbar = tqdm(total=valid_steps, ncols=0, desc=\"Train\", unit=\" step\")\n",
        "\n",
        "    # Save the best model so far.\n",
        "    if (step + 1) % save_steps == 0 and best_state_dict is not None:\n",
        "      torch.save(best_state_dict, save_path)\n",
        "      pbar.write(f\"Step {step + 1}, best model saved. (accuracy={best_accuracy:.4f})\")\n",
        "\n",
        "  pbar.close()\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "  main(**parse_args())\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0R2rx3AyHpQ-"
      },
      "source": [
        "# Inference"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pSuI3WY9Fz78"
      },
      "source": [
        "## Dataset of inference"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4evns0055Dsx"
      },
      "source": [
        "import os\n",
        "import json\n",
        "import torch\n",
        "from pathlib import Path\n",
        "from torch.utils.data import Dataset\n",
        "\n",
        "\n",
        "class InferenceDataset(Dataset):\n",
        "  def __init__(self, data_dir):\n",
        "    testdata_path = Path(data_dir) / \"testdata.json\"\n",
        "    metadata = json.load(testdata_path.open())\n",
        "    self.data_dir = data_dir\n",
        "    self.data = metadata[\"utterances\"]\n",
        "\n",
        "  def __len__(self):\n",
        "    return len(self.data)\n",
        "\n",
        "  def __getitem__(self, index):\n",
        "    utterance = self.data[index]\n",
        "    feat_path = utterance[\"feature_path\"]\n",
        "    mel = torch.load(os.path.join(self.data_dir, feat_path))\n",
        "\n",
        "    return feat_path, mel\n",
        "\n",
        "\n",
        "def inference_collate_batch(batch):\n",
        "  \"\"\"Collate a batch of data.\"\"\"\n",
        "  feat_paths, mels = zip(*batch)\n",
        "\n",
        "  return feat_paths, torch.stack(mels)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oAinHBG1GIWv"
      },
      "source": [
        "## Main function of inference"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yQaTt7VDHoRI"
      },
      "source": [
        "import json\n",
        "import csv\n",
        "from pathlib import Path\n",
        "from tqdm.notebook import tqdm\n",
        "\n",
        "import torch\n",
        "from torch.utils.data import DataLoader\n",
        "\n",
        "def parse_args():\n",
        "  \"\"\"arguments\"\"\"\n",
        "  config = {\n",
        "    \"data_dir\": \"./Dataset\",\n",
        "    \"model_path\": \"./model.ckpt\",\n",
        "    \"output_path\": \"./output.csv\",\n",
        "  }\n",
        "\n",
        "  return config\n",
        "\n",
        "\n",
        "def main(\n",
        "  data_dir,\n",
        "  model_path,\n",
        "  output_path,\n",
        "):\n",
        "  \"\"\"Main function.\"\"\"\n",
        "  device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "  print(f\"[Info]: Use {device} now!\")\n",
        "\n",
        "  mapping_path = Path(data_dir) / \"mapping.json\"\n",
        "  mapping = json.load(mapping_path.open())\n",
        "\n",
        "  dataset = InferenceDataset(data_dir)\n",
        "  dataloader = DataLoader(\n",
        "    dataset,\n",
        "    batch_size=1,\n",
        "    shuffle=False,\n",
        "    drop_last=False,\n",
        "    num_workers=8,\n",
        "    collate_fn=inference_collate_batch,\n",
        "  )\n",
        "  print(\"[Info]: Finish loading data!\", flush=True)\n",
        "\n",
        "  speaker_num = len(mapping[\"id2speaker\"])\n",
        "  model = Classifier(n_spks=speaker_num).to(device)\n",
        "  # map_location lets a checkpoint saved on GPU load on a CPU-only machine.\n",
        "  model.load_state_dict(torch.load(model_path, map_location=device))\n",
        "  model.eval()\n",
        "  print(f\"[Info]: Finish creating model!\",flush = True)\n",
        "\n",
        "  results = [[\"Id\", \"Category\"]]\n",
        "  for feat_paths, mels in tqdm(dataloader):\n",
        "    with torch.no_grad():\n",
        "      mels = mels.to(device)\n",
        "      outs = model(mels)\n",
        "      preds = outs.argmax(1).cpu().numpy()\n",
        "      for feat_path, pred in zip(feat_paths, preds):\n",
        "        results.append([feat_path, mapping[\"id2speaker\"][str(pred)]])\n",
        "  \n",
        "  with open(output_path, 'w', newline='') as csvfile:\n",
        "    writer = csv.writer(csvfile)\n",
        "    writer.writerows(results)\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "  main(**parse_args())\n"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}